# version 11.0
# update from 1.0 which is coded by Yuziyang and Dengli
# version 11.0 upgrade by Xianlong Dai, add ITD normalization algorithm
# 2024.07

import os
import dpkt
import copy
import math
import sub_sq as sbq
import itertools
from collections import Counter
import calFE_BA as cf

import numpy as np
import BasicFlowInfo

import BasicUDP_v2 as BasicUdpInfo
import PacketInfo
from TINTUdpFilter import TINTUdpFilter

from scipy.stats import entropy
from scipy.stats import skew, kurtosis 

class MyError(Exception):
    """Project-specific error; FinalProcess raises it for an unsupported normalization flag.

    The message is exposed both as ``.message`` and through ``str()``.  It is
    also passed to ``Exception.__init__`` so ``args``/``repr`` carry it and
    instances survive pickling.
    """

    def __init__(self, message):
        super().__init__(message)
        self.message = message

    def __str__(self):
        return self.message

# Flow timing limits, compared against pcap packet timestamps
# (presumably seconds -- confirm against PacketInfo).
# deal_packet() closes a TCP flow once packet.ts - flowStartTime reaches this.
flowTimeout = 120
# parse_pcap() sweeps idle flows whenever at least this much time has passed
# since the previous sweep.
activityTimeout = 10

# Feature rows appended by parse_pcap(); FinalProcess() writes them to CSV and
# then resets the list.
global_data = []

def _finish_flow(CurrentFlow, flow_id, flow_info_list):
    """Remove ``flow_id`` from ``CurrentFlow`` and return its remaining tints.

    The flow object is appended to ``flow_info_list`` when it yields tints.
    """
    flow = CurrentFlow[flow_id]
    tint = flow.getTint()
    if tint is not None:
        flow_info_list.append(flow)
    del CurrentFlow[flow_id]
    return tint


def deal_packet(packet, CurrentFlow, DestIp, flow_info_list):
    """Route one TCP packet into its flow and return any tints it releases.

    The packet is oriented relative to ``DestIp``: direction 1 when it is
    addressed to ``DestIp`` ("in"), 0 when sent by it ("out").  Packets that
    involve neither side are ignored.

    A tracked flow ends (timeout, SYN, RST or FIN-ACK) as follows:

    * it has existed for ``flowTimeout`` or longer, or the packet is a SYN:
      the old flow is closed without this packet and a new flow is started
      from it;
    * the packet is a RST, or its ack number equals the flow's ``FinAck2``:
      the packet is added and the flow is closed.

    A closed flow's remaining tints are returned, and the flow object is
    appended to ``flow_info_list`` when there are any.  Otherwise the packet is
    added to the flow and the tints accumulated so far are returned, after
    which the flow's ``tint`` buffer is cleared.

    Args:
        packet: parsed packet (a ``PacketInfo.Packet``).
        CurrentFlow: dict of live flows keyed by flow id; mutated in place.
        DestIp: address of the measured host.
        flow_info_list: list collecting closed flows; mutated in place.

    Returns:
        The tints released by this packet, or ``None``.
    """
    # Direction: 0 = out (sent by DestIp), 1 = in (sent to DestIp).
    if packet.dst_ip == DestIp:
        direct = 1
    elif packet.src_ip == DestIp:
        direct = 0
    else:
        return None

    # Look the flow up under either orientation of its id.
    flow_id = packet.GetfwdFlowId()
    if flow_id not in CurrentFlow:
        flow_id = packet.GetbwdFlowId()
        if flow_id not in CurrentFlow:
            # Unknown flow: a RST cannot open one; any other packet starts tracking.
            if not packet.IsRst():
                CurrentFlow[packet.GetfwdFlowId()] = BasicFlowInfo.BasicFlow(packet, direct)
            return None

    flow = CurrentFlow[flow_id]

    if (packet.ts - flow.flowStartTime) >= flowTimeout or packet.IsSyn():
        # Expired flow or new handshake: close the old flow, restart from this packet.
        tint = _finish_flow(CurrentFlow, flow_id, flow_info_list)
        CurrentFlow[packet.GetfwdFlowId()] = BasicFlowInfo.BasicFlow(packet, direct)
        return tint

    # Decide whether this packet ends the flow BEFORE adding it: addpacket()
    # may update flow state such as FinAck2.
    ends_flow = packet.IsRst() or flow.FinAck2 == packet.ackNum
    flow.addpacket(packet, direct)
    if ends_flow:
        return _finish_flow(CurrentFlow, flow_id, flow_info_list)

    # Flow continues: hand back the tints gathered so far and start a fresh buffer.
    tint = flow.getTint()
    flow.tint = []
    return tint


# UDP filter presets from TINTUdpFilter.  The trailing comment on each line is
# the original author's label for the traffic mix that preset yields.
basicTintUdpFilter = TINTUdpFilter.get_basic_tint_udp_filter()             # udp-filter + tcp
positiveTintUdpFilter = TINTUdpFilter.get_positive_tint_udp_filter()       # udp + tcp
negativeTintUdpFilter = TINTUdpFilter.get_negative_tint_udp_filter()       # tcp
# Active preset: parse_pcap() only feeds a UDP packet to the UDP tint collector
# when udpFilter.filter_debug(...) returns true.
#udpFilter = basicTintUdpFilter
udpFilter = positiveTintUdpFilter

# Number of values per feature group (tcp / udp / all) requested from
# cf.calc_for_Zsc.  The terms mirror the column blocks of FinalProcess's CSV
# header: 2 range + 12 "k_<start>_<end>" + 4 shape (median, skewness,
# kurtosis, iqr) + 5 percentile + 10 entropy columns.
_AK47_ZSC = 2 + 12 + 4 + 5 + 10

# Parse a pcap capture into sliding-window TINT feature rows.

# Offsets added after shifting tints onto their baseline so every normalized
# value is strictly positive (TCP: 1.3e-5, UDP: 1.31e-5).  The same values
# seed the constant placeholder series used when one protocol yields no
# statistics.
_TCP_TINT_OFFSET = 13e-6
_UDP_TINT_OFFSET = 131e-7


def _flush_idle_flows(CurrentFlow, last_active_time, flow_info_list):
    """Close every flow whose last packet is not newer than ``last_active_time``.

    Closed flows are removed from ``CurrentFlow``; those that yield tints are
    appended to ``flow_info_list``.

    Returns:
        The tints released by the closed flows, in iteration order.
    """
    released = []
    idle_ids = []
    for flow_id, flow in CurrentFlow.items():
        if flow.lastPktTime <= last_active_time:
            idle_ids.append(flow_id)
            tint = flow.getTint()
            if tint is not None:
                flow_info_list.append(flow)
                released.extend(tint)
    # Delete after iterating: a dict must not change size while being walked.
    for flow_id in idle_ids:
        del CurrentFlow[flow_id]
    return released


def _normalize_tints(tcp_tints, udp_tints):
    """Shift both tint lists by a shared baseline and add a small positive offset.

    The baseline is the smallest value across the non-empty lists, so TCP and
    UDP tints stay on a common scale.  An empty list stays empty; at least one
    list must be non-empty.

    Returns:
        ``(tcp_normalized, udp_normalized)`` as lists of floats.
    """
    base = min(np.min(tints) for tints in (tcp_tints, udp_tints) if tints)

    def shift(tints, offset):
        if not tints:
            return []
        return (np.asarray(tints, dtype=float) - base + offset).tolist()

    return shift(tcp_tints, _TCP_TINT_OFFSET), shift(udp_tints, _UDP_TINT_OFFSET)


def _append_window_features(tcp_tints, udp_tints):
    """Append one feature row for a big window to ``global_data``.

    The row is ``tcp_features + udp_features + all_features``; each group holds
    ``_AK47_ZSC`` values from ``cf.calc_for_Zsc``.  Nothing is appended when
    both tint lists are empty.
    """
    if not tcp_tints and not udp_tints:
        return
    tcp_norm, udp_norm = _normalize_tints(tcp_tints, udp_tints)

    no_features = [None] * _AK47_ZSC
    tcp_features = cf.calc_for_Zsc(tcp_norm, _AK47_ZSC)
    udp_features = cf.calc_for_Zsc(udp_norm, _AK47_ZSC)
    if tcp_features == no_features:
        # No TCP statistics: fill the TCP group from a constant placeholder
        # series and reuse the UDP group as the combined group.
        tcp_features = cf.calc_for_Zsc([_TCP_TINT_OFFSET] * 5, _AK47_ZSC)
        all_features = udp_features
    elif udp_features == no_features:
        udp_features = cf.calc_for_Zsc([_UDP_TINT_OFFSET] * 5, _AK47_ZSC)
        all_features = tcp_features
    else:
        # NOTE(review): the combined group is computed on the raw (rounded)
        # tints, whereas the per-protocol groups use normalized tints --
        # confirm this asymmetry is intended.
        all_features = cf.calc_for_Zsc(tcp_tints + udp_tints, _AK47_ZSC)

    global_data.append(tcp_features + udp_features + all_features)


def parse_pcap(pcap, DestIp, big_time_windows, small_time_windows, normal_flag):
    """Extract sliding-window TINT feature rows from a pcap into ``global_data``.

    TCP packets are routed through :func:`deal_packet`.  UDP packets accepted
    by ``udpFilter.filter_debug`` are fed to a ``BasicUdpInfo.BasicUdp``
    collector.  Flows idle since the previous sweep are flushed whenever at
    least ``activityTimeout`` has elapsed since that sweep.

    Tints are grouped into consecutive small windows: a window closes on the
    first packet that arrives more than ``small_time_windows`` after the
    window started.  Each time one closes, the small windows that started no
    more than ``big_time_windows`` before the current packet form the big
    window.  Once that big window spans at least
    ``big_time_windows - small_time_windows``, one feature row is appended to
    ``global_data``.

    Flows still open, and the last unclosed small window, are not emitted at
    the end of the capture.

    Args:
        pcap: iterable of ``(timestamp, raw_bytes)`` pairs (TintFile passes a
            ``dpkt.pcap.Reader``).
        DestIp: address of the measured host, used to orient packets.
        big_time_windows: big (feature) window length, in pcap timestamp units.
        small_time_windows: small (step) window length, same units.
        normal_flag: normalization preset; only ``'zsc'`` is supported.

    Raises:
        MyError: if ``normal_flag`` is not ``'zsc'``.
    """
    if normal_flag != 'zsc':
        raise MyError("The normalization method must belong to one of the presets!!!")

    flow_info_list = []   # closed flows (collected, not otherwise read here)
    CurrentFlow = {}
    lastActiveTime = 0
    UDP = BasicUdpInfo.BasicUdp(DestIp)
    udp_debug_result = {}

    tcp_tints = []        # TCP tints of the current small window
    window_start = None   # start timestamp of the current small window
    small_windows = {}    # small-window start -> (tcp_tints, udp_tints)

    for ts, buf in pcap:
        pk = PacketInfo.Packet(ts, buf)

        # Flush flows that have been inactive since the previous sweep.
        if (ts - lastActiveTime) >= activityTimeout:
            tcp_tints.extend(_flush_idle_flows(CurrentFlow, lastActiveTime, flow_info_list))
            lastActiveTime = ts

        if pk.IsTcp():
            flow_tints = deal_packet(pk, CurrentFlow, DestIp, flow_info_list)
            if flow_tints is not None:
                tcp_tints.extend(flow_tints)
        elif pk.IsUdp():
            # Keep only UDP packets whose application protocol the filter
            # recognises for RTT computation.
            if udpFilter.filter_debug(pk, udp_debug_result):
                UDP.add_packet(pk)

        if window_start is None:
            window_start = ts
        if ts - window_start <= small_time_windows:
            continue

        # Close the current small window (this packet's tints included).
        small_windows[window_start] = (tcp_tints, UDP.get_tints())
        window_start = ts
        tcp_tints = []

        # Drop small windows that have slid out of the big window.
        small_windows = {start: tints for start, tints in small_windows.items()
                         if ts - start <= big_time_windows}
        if not small_windows:
            continue
        # Emit only once the big window is (nearly) filled.
        if window_start - min(small_windows) < big_time_windows - small_time_windows:
            continue

        big_tcp = [round(float(t), 6) for tcp, _ in small_windows.values() for t in tcp]
        big_udp = [round(float(t), 6) for _, udp in small_windows.values() for t in udp]
        _append_window_features(big_tcp, big_udp)

                


def get_string_between_brackets(filename):
    """Return the text between the first '[' and the following ']' in ``filename``.

    Tint2 uses this to read the target IP from capture file names, e.g.
    ``"trace[10.0.0.1].pcap"`` -> ``"10.0.0.1"``.  The closing bracket is
    searched only after the opening one, so a stray ']' earlier in the name
    does not hide a valid ``[...]`` pair.

    Args:
        filename: file name to inspect.

    Returns:
        The bracketed substring, or ``''`` when the name has no ``[...]`` pair.
    """
    start = filename.find('[')
    if start == -1:
        return ''
    end = filename.find(']', start + 1)
    if end == -1:
        return ''
    return filename[start + 1:end]



def Tint2(pcap_file_dir, label, big_win_time, small_win_time, A, n_flag):
    """Run FinalProcess over every non-CSV file under ``pcap_file_dir``.

    The destination IP for each capture is taken from the ``[...]`` part of
    its file name.  A ``KeyError`` raised while processing one file is
    reported and that file skipped (its partial feature rows are discarded),
    so the remaining files are still processed.

    Args:
        pcap_file_dir: directory walked recursively for capture files.
        label: class label written to every CSV row.
        big_win_time, small_win_time: window lengths passed to FinalProcess.
        A: output root directory.
        n_flag: normalization preset passed to FinalProcess.

    Returns:
        None.
    """
    for root, _dirs, fnames in os.walk(pcap_file_dir):
        for fname in fnames:
            DestIp = get_string_between_brackets(fname)
            print(DestIp)
            if ".csv" in fname:
                print('******** {}  is not a pcap file! ********'.format(fname))
                continue
            try:
                FinalProcess(os.path.join(root, fname), DestIp, label,
                             big_win_time, small_win_time, A, n_flag)
            except KeyError as err:
                # Drop rows the failed file left behind so they cannot leak
                # into the next file's CSV.
                global_data.clear()
                print('******** {}  skipped: KeyError {} ********'.format(fname, err))
    return None

def TintFile(path, fname, DestIp, big_win_time, small_win_time, normal_flag):
    """Open the pcap at ``path`` and run parse_pcap over it.

    ``fname`` is only used in the progress banner.  The file handle is closed
    when parsing finishes or raises.

    Args:
        path: path of the pcap file.
        fname: name printed in the banner.
        DestIp: address of the measured host.
        big_win_time, small_win_time: window lengths for parse_pcap.
        normal_flag: normalization preset for parse_pcap.
    """
    with open(path, 'rb') as pcap_file:
        reader = dpkt.pcap.Reader(pcap_file)
        print('******** {} ********'.format(fname))
        parse_pcap(reader, DestIp, big_win_time, small_win_time, normal_flag)


def _zsc_group_columns(kind):
    """Return the CSV column names of one 'zsc' feature group.

    ``kind`` is 'tcp', 'udp' or 'all'.  NOTE(review): the column order is
    assumed to match the order of values produced by cf.calc_for_Zsc --
    confirm against calFE_BA.
    """
    k_bounds = [(0, 25), (0, 50), (0, 75), (0, 95), (25, 50), (25, 75),
                (45, 95), (50, 75), (25, 95), (75, 95), (50, 95), (0, 100)]
    entropy_stats = ['tentr', 'hentr', 'min_tentr', 'min_hentr', 'max_tentr',
                     'max_hentr', 'range_tentr', 'range_hentr', 'p50_tentr', 'p50_hentr']

    cols = [f'{kind}_range_ori', f'{kind}_range_nor']
    cols += [f'{kind[0]}k_{start}_{end}' for start, end in k_bounds]
    cols += [f'{kind}_{stat}' for stat in ('median', 'skewness', 'kurtosis', 'iqr')]
    cols += [f'{kind}_{p}%_percentile' for p in (5, 25, 35, 90, 95)]
    cols += [f'{kind}_{stat}' for stat in entropy_stats]
    return cols


def FinalProcess(pcap_path, target_ip, label, big_win_time, small_win_time, A, normal_flag):
    """Extract features from one pcap and write them to a labelled CSV under ``A``.

    The CSV goes to ``A/<big>_<small>/1+<label>/`` and is named
    ``<big>_<small>_<n>fe_1+<label>_<pcap name without '.pcap'>.csv``, where
    ``<n>`` is the number of feature columns.  Each row is one feature row
    produced by parse_pcap followed by ``label``; ``None`` features become
    empty cells.  No file is written when the capture yields no rows.

    Args:
        pcap_path: path of the capture file.
        target_ip: address of the measured host.
        label: class label appended to every row.
        big_win_time, small_win_time: window lengths.
        A: output root directory (created if missing).
        normal_flag: normalization preset; only 'zsc' is supported.

    Raises:
        MyError: if ``normal_flag`` is not 'zsc'.
    """
    global global_data

    os.makedirs(A, exist_ok=True)

    if normal_flag != 'zsc':
        raise MyError("The normalization method must belong to one of the presets!!!")
    columns = [col for kind in ('tcp', 'udp', 'all') for col in _zsc_group_columns(kind)]
    header = ','.join(columns + ['label']) + '\n'

    # Start from an empty buffer so rows left by an earlier run that raised
    # cannot leak into this CSV, and always reset it afterwards.
    global_data = []
    try:
        TintFile(pcap_path, pcap_path, target_ip, big_win_time, small_win_time, normal_flag)
        rows = global_data
    finally:
        global_data = []
    print('lines:', len(rows))

    # os.path.join instead of hard-coded backslashes (the original '\{' was
    # also an invalid escape sequence).
    out_dir = os.path.join(A, f'{big_win_time}_{small_win_time}', f'1+{label}')
    os.makedirs(out_dir, exist_ok=True)

    pcap_name = os.path.basename(pcap_path)
    new_name = (f'{big_win_time}_{small_win_time}_{len(columns)}fe_1+{label}_'
                + pcap_name.replace('.pcap', ''))

    if rows:
        label_text = str(label)
        with open(os.path.join(out_dir, new_name) + '.csv', 'w') as out:
            out.write(header)
            for row in rows:
                cells = ['' if item is None else str(item) for item in row]
                cells.append(label_text)
                out.write(','.join(cells) + '\n')


if __name__ == '__main__':
    import sys

    # Capture directories grouped into numbered sets; fill these in before
    # running.  A directory whose path contains '+0' is labelled 0, any other
    # directory 1.
    pcap_file_dir_list_1 = []
    pcap_file_dir_list_2 = []
    pcap_file_dir_list_3 = []
    pcap_file_dir_list_4 = []
    pcap_file_dir_list_5 = []
    pcap_file_dir_list_6 = []
    p7 = []
    p8 = []

    pcap_file_dir_dict = {
        1: pcap_file_dir_list_1,
        2: pcap_file_dir_list_2,
        3: pcap_file_dir_list_3,
        4: pcap_file_dir_list_4,
        5: pcap_file_dir_list_5,
        6: pcap_file_dir_list_6,
        7: p7,
        8: p8,
    }

    # Ask the user which set to process.
    inss = input(f"请输入一个数字(1-{len(pcap_file_dir_dict)}):")
    try:
        choice = int(inss)
    except ValueError:
        print("输入的不是一个有效的数字")
        sys.exit(1)
    # Validate against the dict itself instead of a hard-coded upper bound.
    if choice not in pcap_file_dir_dict:
        print(f"输入的数字不在有效范围内(1-{len(pcap_file_dir_dict)})")
        sys.exit(1)
    pcap_file_dir_list = pcap_file_dir_dict[choice]

    big_win_time = [60]
    small_win_time = [1]
    AAAA = r'Z:\BA0725'
    n_flag = 'zsc'

    for big_win, small_win in itertools.product(big_win_time, small_win_time):
        # A big window must be longer than its step.
        if big_win <= small_win:
            continue
        print("")
        print("")
        print(f'!!!  #_# *****************   \"big_winsize: {big_win}  small_winsize: {small_win}\"   ***************** #_#  !!!')
        print("")

        for pcap_file_dir in pcap_file_dir_list:
            label = 0 if '+0' in pcap_file_dir else 1
            Tint2(pcap_file_dir, label, big_win, small_win, AAAA, n_flag)